Master's thesis case study 3: Bandit's with stopping¶
In [1]:
%load_ext autoreload
%autoreload 2
In [2]:
import numpy
import torch
from adaptive_nof1 import *
from adaptive_nof1.policies import *
from adaptive_nof1.helpers import *
from adaptive_nof1.inference import *
from adaptive_nof1.metrics import *
from matplotlib import pyplot as plt
import seaborn
from adaptive_nof1.patient_explorer import show_patient_explorer
WARNING (pytensor.tensor.blas): Using NumPy C-API based implementation for BLAS functions.
In [3]:
# Setup generic n-of-1 parameters
block_length = 5
max_length = 10 * block_length
number_of_actions = 2
number_of_patients = 100
In [4]:
# Scenarios
class NormalModel(Model):
def __init__(self, patient_id, mean, variance):
self.rng = numpy.random.default_rng(patient_id)
self.mean = mean
self.variance = variance
self.patient_id = patient_id
def multivariate_normal_distribution(debug_data):
cov = torch.diag_embed(torch.tensor(numpy.sqrt(self.variance)))
return torch.distributions.MultivariateNormal(torch.tensor(self.mean), cov)
def generate_context(self, history):
return {}
@property
def additional_config(self):
return {"expectations_of_interventions": self.mean}
@property
def number_of_interventions(self):
return len(self.mean)
def observe_outcome(self, action, context):
treatment_index = action["treatment"]
return {"outcome": self.rng.normal(self.mean[treatment_index], numpy.sqrt(self.variance[treatment_index]))}
def __str__(self):
return f"NormalModel({self.mean, self.variance})"
generating_scenario_I = lambda patient_id: NormalModel(patient_id, mean=[0, 0], variance=[1,1])
generating_scenario_II = lambda patient_id: NormalModel(patient_id, mean=[1, 0], variance=[1,1])
generating_scenario_III = lambda patient_id: NormalModel(patient_id, mean=[2, 0], variance=[1,1])
In [5]:
# Inference Model
inference_model = lambda: NormalKnownVariance(prior_mean=0, prior_variance=1, variance=1)
# Stopping Time
ALPHA_STOPPING = 0.01
def alpha_stopping_time(history, context):
model = NormalKnownVariance(prior_mean=0, prior_variance=1, variance=1)
model.update_posterior(history, number_of_actions)
probabilities = model.approximate_max_probabilities(number_of_actions, context)
return 1 - max(probabilities) < ALPHA_STOPPING
In [6]:
# Policies
fixed_policy = StoppingPolicy(
policy = BlockPolicy(
block_length = block_length,
internal_policy = FixedPolicy(
number_of_actions=2,
inference_model = inference_model(),
),
),
stopping_time = alpha_stopping_time,
)
explore_then_commit = StoppingPolicy(
policy= BlockPolicy(
block_length = block_length,
internal_policy = ExploreThenCommit(
number_of_actions=2,
exploration_length=4,
block_length = block_length,
inference_model = inference_model(),
),
),
stopping_time = alpha_stopping_time,
)
thompson_sampling_policy = StoppingPolicy(
policy = BlockPolicy(
block_length = block_length,
internal_policy = ThompsonSampling(
inference_model=inference_model(),
number_of_actions=2,
),
),
stopping_time = alpha_stopping_time,
)
ucb_policy = StoppingPolicy(
policy = BlockPolicy(
block_length = block_length,
internal_policy = UpperConfidenceBound(
inference_model=inference_model(),
number_of_actions=2,
epsilon=0.05,
),
),
stopping_time = alpha_stopping_time,
)
In [7]:
# Full crossover study
study_designs = {
"n_patients": [number_of_patients],
"policy": [fixed_policy, explore_then_commit, thompson_sampling_policy, ucb_policy],
"model_from_patient_id": [
generating_scenario_I, generating_scenario_II, generating_scenario_III,
]
}
configurations = generate_configuration_cross_product(study_designs)
In [8]:
ENABLE_SIMULATION = False
if ENABLE_SIMULATION:
print("Simulation was enabled")
else:
print("Simulation was disabled")
Simulation was enabled
In [9]:
if ENABLE_SIMULATION:
calculated_series, config_to_simulation_data = simulate_configurations(
configurations, max_length
)
0%| | 0/50 [00:00<?, ?it/s]
0%| | 0/50 [00:00<?, ?it/s]
0%| | 0/50 [00:00<?, ?it/s]
0%| | 0/50 [00:00<?, ?it/s]
0%| | 0/50 [00:00<?, ?it/s]
0%| | 0/50 [00:00<?, ?it/s]
0%| | 0/50 [00:00<?, ?it/s]
0%| | 0/50 [00:00<?, ?it/s]
0%| | 0/50 [00:00<?, ?it/s]
0%| | 0/50 [00:00<?, ?it/s]
0%| | 0/50 [00:00<?, ?it/s]
0%| | 0/50 [00:00<?, ?it/s]
In [10]:
if ENABLE_SIMULATION:
write_to_disk("data/2024-02-11-mt_case_study_3_data.json", [calculated_series, config_to_simulation_data])
else:
calculated_series, config_to_simulation_data = load_from_disk("data/2024-02-11-mt_case_study_3_data.json")
In [23]:
# Todo: make the output table in a way that we chose the maximum index
def debug_data_to_torch_distribution(debug_data):
mean = debug_data["mean"]
# + the true variance of 1
standard_deviation = numpy.sqrt(numpy.array(debug_data["variance"]) + 1)
cov = torch.diag_embed(torch.tensor(standard_deviation))
return torch.distributions.MultivariateNormal(torch.tensor(mean), cov)
def data_to_true_distribution(data):
mean = data.additional_config["expectations_of_interventions"]
cov = torch.eye(len(mean))
return torch.distributions.MultivariateNormal(torch.tensor(mean), cov)
metrics = [
SimpleRegretWithMean(),
BestArmIdentification(),
CumulativeRegret(),
Length(),
KLDivergence(data_to_true_distribution = data_to_true_distribution, debug_data_to_posterior_distribution=debug_data_to_torch_distribution),
]
model_mapping = {
"NormalModel(([0, 0], [1, 1]))": "I",
"NormalModel(([1, 0], [1, 1]))": "II",
"NormalModel(([2, 0], [1, 1]))": "III",
}
policy_mapping = {
"StoppingPolicy(BlockPolicy(FixedPolicy))": "Fixed",
"StoppingPolicy(BlockPolicy(ThompsonSampling(NormalKnownVariance(0, 1, 1))))": "TS",
"StoppingPolicy(BlockPolicy(UpperConfidenceBound(0.05 epsilon, NormalKnownVariance(0, 1, 1))))": "UCB",
"StoppingPolicy(BlockPolicy(ExploreThenCommit(4,NormalKnownVariance(0, 1, 1))))": "ETC",
}
df = SeriesOfSimulationsData.score_data(
[s["result"] for s in calculated_series], metrics, {"model": lambda x: model_mapping[x], "policy": lambda x: policy_mapping[x]}
)
df = df.reset_index(drop=True)
max_t_indices = df.groupby(["policy", "metric", "model", "patient_id"])["t"].idxmax()
filtered_df = df.iloc[max_t_indices]
filtered_df = filtered_df.reset_index(drop=True)
groupby_columns = ["model", "policy"]
pivoted_df = filtered_df.pivot(
index=["model", "policy", "patient_id"],
columns="metric",
values="score",
)
table = pivoted_df.groupby(groupby_columns).agg(['mean', 'std'])
policy_ordering = ["Fixed", "ETC", "UCB", "TS"]
# Convert the 'policy' column in the MultiIndex to a Categorical type with the specified order
table = table.reset_index()
table['policy'] = pd.Categorical(table['policy'], categories=policy_ordering, ordered=True)
# Sort the DataFrame first by 'model' then by the now-ordered 'policy'
sorted_table = table.sort_values(by=['model', 'policy']).set_index(groupby_columns)[["Cumulative Regret (outcome)", "KL Divergence", "Simple Regret With Mean", "Length", "Best Arm Identification With Mean"]]
sorted_table.index.names = ["S.", "Policy"]
sorted_table
Out[23]:
| metric | Cumulative Regret (outcome) | KL Divergence | Simple Regret With Mean | Length | Best Arm Identification With Mean | ||||||
|---|---|---|---|---|---|---|---|---|---|---|---|
| mean | std | mean | std | mean | std | mean | std | mean | std | ||
| S. | Policy | ||||||||||
| I | Fixed | -1.015252 | 7.228124 | 0.058222 | 0.108264 | 0.0 | 0.000000 | 48.11 | 8.292500 | 0.54 | 0.500908 |
| ETC | -0.906465 | 7.043759 | 0.068086 | 0.104436 | 0.0 | 0.000000 | 47.87 | 8.621860 | 0.49 | 0.502418 | |
| UCB | -0.990027 | 7.111509 | 0.093519 | 0.123903 | 0.0 | 0.000000 | 47.94 | 8.338568 | 0.52 | 0.502117 | |
| TS | -1.03192 | 7.280245 | 0.073209 | 0.099991 | 0.0 | 0.000000 | 48.48 | 7.134324 | 0.51 | 0.502418 | |
| II | Fixed | -13.437944 | 8.339200 | 0.064499 | 0.055955 | 0.0 | 0.000000 | 24.04 | 13.325695 | 1.0 | 0.000000 |
| ETC | -18.781947 | 14.631094 | 0.070053 | 0.063430 | 0.0 | 0.000000 | 26.43 | 15.462078 | 1.0 | 0.000000 | |
| UCB | -33.334714 | 21.230854 | 0.149162 | 0.231549 | 0.02 | 0.140705 | 37.12 | 17.352716 | 0.98 | 0.140705 | |
| TS | -24.481239 | 17.566796 | 0.153106 | 0.242918 | 0.01 | 0.100000 | 33.44 | 15.109245 | 0.99 | 0.100000 | |
| III | Fixed | -12.886026 | 5.586778 | 0.181561 | 0.183543 | 0.0 | 0.000000 | 10.41 | 4.109843 | 1.0 | 0.000000 |
| ETC | -13.261922 | 6.508062 | 0.1809 | 0.182769 | 0.0 | 0.000000 | 10.59 | 4.408658 | 1.0 | 0.000000 | |
| UCB | -46.015805 | 42.648110 | 0.872675 | 0.977566 | 0.28 | 0.697470 | 26.03 | 19.479829 | 0.86 | 0.348735 | |
| TS | -42.058965 | 42.416344 | 0.899014 | 0.987336 | 0.26 | 0.675995 | 25.34 | 18.321054 | 0.87 | 0.337998 | |
In [12]:
with open('mt_resources/7-stopping/01-table-part-1.tex', 'w') as file:
str = sorted_table[["Cumulative Regret (outcome)", "KL Divergence", "Simple Regret With Mean"]].style.format(precision=1).to_latex(hrules=True)
print(str)
file.write(str)
\begin{tabular}{lllrlrlr}
\toprule
& metric & \multicolumn{2}{r}{Cumulative Regret (outcome)} & \multicolumn{2}{r}{KL Divergence} & \multicolumn{2}{r}{Simple Regret With Mean} \\
& & mean & std & mean & std & mean & std \\
S. & Policy & & & & & & \\
\midrule
\multirow[c]{4}{*}{I} & Fixed & -1.0 & 7.2 & 0.1 & 0.1 & 0.0 & 0.0 \\
& ETC & -0.9 & 7.0 & 0.1 & 0.1 & 0.0 & 0.0 \\
& UCB & -1.0 & 7.1 & 0.1 & 0.1 & 0.0 & 0.0 \\
& TS & -1.0 & 7.3 & 0.1 & 0.1 & 0.0 & 0.0 \\
\multirow[c]{4}{*}{II} & Fixed & -13.4 & 8.3 & 0.1 & 0.1 & 0.0 & 0.0 \\
& ETC & -18.8 & 14.6 & 0.1 & 0.1 & 0.0 & 0.0 \\
& UCB & -33.3 & 21.2 & 0.1 & 0.2 & 0.0 & 0.1 \\
& TS & -24.5 & 17.6 & 0.2 & 0.2 & 0.0 & 0.1 \\
\multirow[c]{4}{*}{III} & Fixed & -12.9 & 5.6 & 0.2 & 0.2 & 0.0 & 0.0 \\
& ETC & -13.3 & 6.5 & 0.2 & 0.2 & 0.0 & 0.0 \\
& UCB & -46.0 & 42.6 & 0.9 & 1.0 & 0.3 & 0.7 \\
& TS & -42.1 & 42.4 & 0.9 & 1.0 & 0.3 & 0.7 \\
\bottomrule
\end{tabular}
In [34]:
with open('mt_resources/7-stopping/01-table-part-2.tex', 'w') as file:
str = sorted_table[["Length", "Best Arm Identification With Mean"]].style.format(precision=2).to_latex(hrules=True)
print(str)
file.write(str)
\begin{tabular}{lllrlr}
\toprule
& metric & \multicolumn{2}{r}{Length} & \multicolumn{2}{r}{Best Arm Identification With Mean} \\
& & mean & std & mean & std \\
S. & Policy & & & & \\
\midrule
\multirow[c]{4}{*}{I} & Fixed & 48.11 & 8.29 & 0.54 & 0.50 \\
& ETC & 47.87 & 8.62 & 0.49 & 0.50 \\
& UCB & 47.94 & 8.34 & 0.52 & 0.50 \\
& TS & 48.48 & 7.13 & 0.51 & 0.50 \\
\multirow[c]{4}{*}{II} & Fixed & 24.04 & 13.33 & 1.00 & 0.00 \\
& ETC & 26.43 & 15.46 & 1.00 & 0.00 \\
& UCB & 37.12 & 17.35 & 0.98 & 0.14 \\
& TS & 33.44 & 15.11 & 0.99 & 0.10 \\
\multirow[c]{4}{*}{III} & Fixed & 10.41 & 4.11 & 1.00 & 0.00 \\
& ETC & 10.59 & 4.41 & 1.00 & 0.00 \\
& UCB & 26.03 & 19.48 & 0.86 & 0.35 \\
& TS & 25.34 & 18.32 & 0.87 & 0.34 \\
\bottomrule
\end{tabular}
In [14]:
def rename_df(df):
df["policy_#_metric_#_model_p"] = df["policy"].apply(lambda x: policy_mapping[x])
df['policy'] = pd.Categorical(df['policy_#_metric_#_model_p'], categories=policy_ordering, ordered=True)
return df
SeriesOfSimulationsData.plot_lines(
[s["result"] for s in calculated_series if s["configuration"]["model"] == "NormalModel(([1, 0], [1, 1]))"],
[
CumulativeRegret(),
],
legend_position=(0.02,0.3),
process_df = rename_df,
)
plt.ylabel('Regret')
plt.savefig("mt_resources/7-stopping/01_cumulative_regret.pdf", bbox_inches="tight")
In [15]:
SeriesOfSimulationsData.plot_lines(
[s["result"] for s in calculated_series if s["configuration"]["model"] == "NormalModel(([1, 0], [1, 1]))"],
[
SimpleRegretWithMean(),
],
legend_position=(0.8,1.0),
process_df = rename_df,
)
plt.ylabel('Simple Regret')
plt.savefig("mt_resources/7-stopping/01_simple_regret.pdf", bbox_inches="tight")
In [16]:
SeriesOfSimulationsData.plot_lines(
[s["result"] for s in calculated_series if s["configuration"]["model"] == "NormalModel(([1, 0], [1, 1]))"],
[
KLDivergence(data_to_true_distribution = data_to_true_distribution, debug_data_to_posterior_distribution=debug_data_to_torch_distribution)
],
legend_position=(0.8,1.0),
process_df = rename_df,
)
plt.ylabel('KL Divergence')
plt.savefig("mt_resources/7-stopping/01-kl-divergence.pdf", bbox_inches="tight")
In [26]:
df = SeriesOfSimulationsData.score_data(
[s["result"] for s in calculated_series if s["configuration"]["model"] == "NormalModel(([1, 0], [1, 1]))"],
[ IsStopped() ],
)
groupby_df_sum = rename_df(df).groupby(["policy", "model", "t"]).sum()
ax = seaborn.lineplot(
data=groupby_df_sum,
x="t",
y="score",
hue="policy",
# units="patient_id",
#estimator=numpy.median,
#errorbar=lambda x: (numpy.quantile(x, 0.25), numpy.quantile(x, 0.75)),
)
plt.ylabel("Number of patients")
seaborn.move_legend(ax, "upper right", title=None)
plt.savefig("mt_resources/7-stopping/01_is_stopped.pdf", bbox_inches="tight")
/var/folders/2g/v44yvb1n6sdgnp5mwbh8_qgc0000gn/T/ipykernel_95227/3657439798.py:5: FutureWarning: The default of observed=False is deprecated and will be changed to True in a future version of pandas. Pass observed=False to retain current behavior or observed=True to adopt the future default and silence this warning. groupby_df_sum = rename_df(df).groupby(["policy", "model", "t"]).sum()
In [18]:
plot_allocations_for_calculated_series(calculated_series)
/opt/homebrew/Caskroom/miniconda/base/envs/mt/lib/python3.11/site-packages/holoviews/plotting/bokeh/plot.py:987: UserWarning: found multiple competing values for 'toolbar.active_drag' property; using the latest value layout_plot = gridplot( /opt/homebrew/Caskroom/miniconda/base/envs/mt/lib/python3.11/site-packages/holoviews/plotting/bokeh/plot.py:987: UserWarning: found multiple competing values for 'toolbar.active_scroll' property; using the latest value layout_plot = gridplot(
Out[18]: